{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Graph Classification with DGL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates how to use DGL for graph classification tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "from dgl.data import TUDataset\n",
    "from dgl.data.utils import split_dataset\n",
    "from dgl.nn.pytorch import conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import networkx as nx\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import BCEWithLogitsLoss\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./asset/enzymes.png\" width=\"500\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we use the ENZYMES dataset. It constructs a graph from each enzyme based on group functions. Nodes represent structural elements and edges represent the connections between them. Each graph has a label from 0 to 5, which indicates the type of the enzyme."
   ]
  },
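  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = TUDataset(\"ENZYMES\")\n",
    "\n",
    "# Convert the labels to a tensor so they can be stacked when batching\n",
    "dataset.graph_labels = torch.tensor(dataset.graph_labels)\n",
    "# Cast the node attributes to float, matching the float weights of the model\n",
    "for i in range(len(dataset)):\n",
    "    dataset[i][0].ndata['node_attr'] = dataset[i][0].ndata['node_attr'].float()"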
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph, label = dataset[0]\n",
    "print(graph)\n",
    "print(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nx.draw_spring(graph.to_networkx())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Split dataset into train and val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainset, valset = split_dataset(dataset, [0.8, 0.2], shuffle=True, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DGL can batch multiple small graphs together to accelerate computation. Details of batching can be found [here](https://docs.dgl.ai/tutorials/basics/4_batch.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/batch.png\" width=\"500\"/>"
   ]
  },
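  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for what batching does, the cell below is a minimal sketch that uses dense NumPy adjacency matrices rather than DGL. It assumes a batch can be viewed as one large graph with disjoint components, so its adjacency matrix is block-diagonal. `batch_adjacency` and the toy graphs are hypothetical names for illustration only; they are not part of the DGL API and do not describe how `dgl.batch` is implemented. The returned `graph_ids` array records which input graph each node came from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch only: illustrates the idea of batching with dense arrays.\n",
    "def batch_adjacency(adjs):\n",
    "    # Place each graph's adjacency matrix on the diagonal of one big matrix\n",
    "    sizes = [a.shape[0] for a in adjs]\n",
    "    total = sum(sizes)\n",
    "    batched = np.zeros((total, total), dtype=adjs[0].dtype)\n",
    "    offset = 0\n",
    "    for a, n in zip(adjs, sizes):\n",
    "        batched[offset:offset + n, offset:offset + n] = a\n",
    "        offset += n\n",
    "    # graph_ids[i] is the index of the input graph that node i belongs to\n",
    "    graph_ids = np.repeat(np.arange(len(adjs)), sizes)\n",
    "    return batched, graph_ids\n",
    "\n",
    "# Hypothetical toy graphs: a single edge (2 nodes) and a path (3 nodes)\n",
    "toy_a = np.array([[0, 1], [1, 0]])\n",
    "toy_b = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])\n",
    "toy_batched, toy_graph_ids = batch_adjacency([toy_a, toy_b])\n",
    "print(toy_batched)\n",
    "print(toy_graph_ids)"
   ]
  },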
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_molgraphs_for_classification(data):\n",
    "    \"\"\"Batching a list of datapoints for dataloader in classification tasks.\"\"\"\n",
    "    graphs, labels = map(list, zip(*data))\n",
    "    bg = dgl.batch(graphs)\n",
    "    labels = torch.stack(labels, dim=0)\n",
    "    return bg, labels\n",
    "\n",
    "train_loader = DataLoader(trainset, batch_size=512,\n",
    "                          collate_fn=collate_molgraphs_for_classification)\n",
    "val_loader = DataLoader(valset, batch_size=512,\n",
    "                        collate_fn=collate_molgraphs_for_classification)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Model and Optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we use a three-layer Graph Convolutional Network (GCN) to classify the graphs. Detailed source code can be found [here](https://github.com/dmlc/dgl/blob/master/python/dgl/model_zoo/chem/classifiers.py#L111)."
   ]
  },
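  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what one graph convolution layer computes, the cell below sketches the widely used GCN propagation rule $H' = \\mathrm{ReLU}(D^{-1/2} A D^{-1/2} H W)$ in NumPy. Treat it as a sketch under stated assumptions: `gcn_layer` is a hypothetical helper with no bias term and no self-loops, and the exact normalization applied by `conv.GraphConv` depends on its options and on the DGL version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch only: one GCN step on a dense adjacency matrix (hypothetical helper).\n",
    "def gcn_layer(adj, feats, weight):\n",
    "    deg = adj.sum(axis=1)\n",
    "    # Clip degrees to at least 1 so isolated nodes do not divide by zero\n",
    "    inv_sqrt = 1.0 / np.sqrt(np.clip(deg, 1, None))\n",
    "    # Symmetric normalization: D^-1/2 A D^-1/2\n",
    "    norm_adj = inv_sqrt[:, None] * adj * inv_sqrt[None, :]\n",
    "    # Aggregate neighbor features, apply the weights, then ReLU\n",
    "    return np.maximum(norm_adj @ feats @ weight, 0.0)\n",
    "\n",
    "# Hypothetical toy input: a 3-node path graph with one-hot node features\n",
    "toy_adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])\n",
    "toy_feats = np.eye(3)\n",
    "toy_weight = np.ones((3, 2))\n",
    "toy_out = gcn_layer(toy_adj, toy_feats, toy_weight)\n",
    "print(toy_out)"
   ]
  },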
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use a structure similar to the one introduced before: a 3-layer GNN learns the node-level representations. Then the built-in readout function `dgl.sum_nodes` sums all the node (vertex) representations of a graph to obtain the graph representation. $$h_g=\\sum_{v \\in g} h_v$$  \n",
    "Finally, a linear classifier predicts the class of the graph from its representation."
   ]
  },
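  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sum readout can be sketched without DGL as a segment sum. The cell below assumes that the node features of a batch are stacked row-wise and that an index array marks the graph each node belongs to. `sum_readout` and the toy arrays are hypothetical names; the sketch illustrates the formula above and is not how `dgl.sum_nodes` is implemented."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch only: sum the node features of each graph in a batch.\n",
    "def sum_readout(node_feats, graph_ids, num_graphs):\n",
    "    out = np.zeros((num_graphs, node_feats.shape[1]), dtype=node_feats.dtype)\n",
    "    # np.add.at accumulates every row into the slot given by its graph id\n",
    "    np.add.at(out, graph_ids, node_feats)\n",
    "    return out\n",
    "\n",
    "# Hypothetical toy batch: 2 nodes in graph 0 and 3 nodes in graph 1\n",
    "toy_node_feats = np.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 10.]])\n",
    "toy_ids = np.array([0, 0, 1, 1, 1])\n",
    "toy_h_g = sum_readout(toy_node_feats, toy_ids, 2)\n",
    "print(toy_h_g)"
   ]
  },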
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCNModel(nn.Module):\n",
    "    def __init__(self,\n",
    "                 in_feats,\n",
    "                 n_hidden,\n",
    "                 out_feats):\n",
    "        super().__init__()\n",
    "        self.layers = nn.ModuleList([\n",
    "            conv.GraphConv(in_feats, n_hidden, activation=F.relu),\n",
    "            conv.GraphConv(n_hidden, n_hidden, activation=F.relu),\n",
    "            conv.GraphConv(n_hidden, n_hidden, activation=F.relu)\n",
    "        ])\n",
    "        \n",
    "        self.classifier = nn.Linear(n_hidden, out_feats)\n",
    "\n",
    "    def forward(self, g, features):\n",
    "        h = features\n",
    "        for layer in self.layers:\n",
    "            h = layer(g, h)\n",
    "        with g.local_scope():\n",
    "            g.ndata['feat'] = h\n",
    "            h_g = dgl.sum_nodes(g, 'feat')\n",
    "        return self.classifier(h_g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "epochs = 500 if torch.cuda.is_available() else 50\n",
    "model = GCNModel(in_feats=18, n_hidden=64, out_feats=6).to(device)\n",
    "loss_criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = Adam(model.parameters())\n",
    "print(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "model.train()\n",
    "for i in range(epochs):\n",
    "    loss_list = []\n",
    "    true_samples = 0\n",
    "    num_samples = 0\n",
    "    for batch_id, batch_data in enumerate(train_loader):\n",
    "        bg, labels = batch_data\n",
    "        atom_feats = bg.ndata.pop('node_attr').float()\n",
    "        atom_feats, labels = atom_feats.to(device), \\\n",
    "                                   labels.to(device).squeeze(-1)\n",
    "        logits = model(bg, atom_feats)\n",
    "        loss = loss_criterion(logits, labels)\n",
    "        true_samples += (logits.argmax(1)==labels.long()).float().sum().item()\n",
    "        num_samples += len(labels)\n",
    "        loss_list.append(loss.item())\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    print(\"Epoch {:05d} | Loss: {:.4f} | Accuracy: {:.4f}\".format(i, np.mean(loss_list), true_samples/num_samples))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "true_samples = 0\n",
    "num_samples = 0\n",
    "with torch.no_grad():\n",
    "    for batch_id, batch_data in enumerate(val_loader):\n",
    "        bg, labels = batch_data\n",
    "        atom_feats = bg.ndata.pop('node_attr')\n",
    "        atom_feats, labels = atom_feats.to(device), \\\n",
    "                                   labels.to(device).squeeze(-1)\n",
    "        logits = model(bg, atom_feats)\n",
    "        num_samples += len(labels)\n",
    "        true_samples += (logits.argmax(1)==labels.long()).float().sum().item()\n",
    "print(\"Validation Accuracy: {:.4f}\".format(true_samples/num_samples))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are other built-in readout functions, such as `max_nodes` and `mean_nodes`. Docs can be found [here](https://docs.dgl.ai/api/python/batch.html#graph-readout). Try replacing `sum_nodes` with one of them to see whether you can achieve better performance.\n",
    "You can also change the network structure, as in the exercise of `BasicTask.ipynb`."
   ]
  },
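  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about how the alternative readouts differ, here is a small NumPy sketch of mean and max pooling over the nodes of each graph. `toy_readout` and the toy data are hypothetical, and the sketch assumes every graph has at least one node. In the exercise itself you would call the DGL readout functions inside the model instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch only: pool the node features of each graph with a chosen reduction.\n",
    "def toy_readout(node_feats, graph_ids, num_graphs, how='mean'):\n",
    "    reducers = {'mean': np.mean, 'max': np.max}\n",
    "    if how not in reducers:\n",
    "        raise ValueError('unknown readout: ' + how)\n",
    "    # Select the rows of each graph with a boolean mask and reduce over nodes\n",
    "    rows = [reducers[how](node_feats[graph_ids == g], axis=0)\n",
    "            for g in range(num_graphs)]\n",
    "    return np.stack(rows)\n",
    "\n",
    "# Hypothetical toy batch: 2 nodes in graph 0 and 3 nodes in graph 1\n",
    "pool_feats = np.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 10.]])\n",
    "pool_ids = np.array([0, 0, 1, 1, 1])\n",
    "pool_mean = toy_readout(pool_feats, pool_ids, 2, how='mean')\n",
    "pool_max = toy_readout(pool_feats, pool_ids, 2, how='max')\n",
    "print(pool_mean)\n",
    "print(pool_max)"
   ]
  },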
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
